package org.zjvis.datascience.spark.algorithm;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import org.apache.commons.cli.*;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.zjvis.datascience.spark.util.*;

import java.io.Serializable;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

/**
 * @description 机器学习算子通用父类
 * @date 2021-12-23
 */
public class BaseAlgorithm implements Serializable {

    protected final static Logger logger = LoggerFactory.getLogger("BaseAlgorithm");

    protected CacheUtil cacheUtil;

    transient protected Options options = new Options();

    transient protected CommandLineParser parser = new PosixParser();

    transient protected HelpFormatter formatter = new HelpFormatter();

    protected SparkSession sparkSession = null;

    transient protected CommandLine cmd = null;

    protected String algorithmName;

    protected String idCol;

    protected String uniqueKey;

    protected Long taskInstanceId;

    protected boolean isHive = false;

    protected boolean isSample = false;

    protected double sampleRatio = 0.2;

    protected int sampleNumber = 1000;

    protected OutputResult retResult = new OutputResult();

    public BaseAlgorithm(SparkSession sparkSession) {
        this.sparkSession = sparkSession;
        cacheUtil = new CacheUtil(sparkSession);
    }

    public String getRetResult() {
        return JSONObject.toJSONString(retResult);
    }

    public boolean parseParams(String[] args) {
        options = new Options();
        Option uk = new Option("uk", "uniqueKey", true, "unique key");
        uk.setRequired(true);
        options.addOption(uk);
        Option helper = new Option("h", "help", true, "help");
        helper.setRequired(false);
        options.addOption(helper);
        options.addOption(
                OptionHelper.getSpecOption("id", "taskInstanceId", "task instance id", true));
        options.addOption(OptionHelper.getSpecOption("idCol"));
        options.addOption(OptionHelper.getSpecOption("sample"));
        return true;
    }

    protected boolean initCommandLine(String[] args) {
        try {
            cmd = parser.parse(options, Arrays.copyOfRange(args, 1, args.length));
        } catch (ParseException e) {
            logger.error(e.getMessage(), e);
            formatter.printHelp("utility-name", options);
            return false;
        }

        if (cmd.hasOption("h")) {
            formatter.printHelp("utility-name", options);
            return false;
        }
        if (cmd.hasOption("uk") || cmd.hasOption("uniqueKey")) {
            uniqueKey = cmd.getOptionValue("uniqueKey");
        }
        if (cmd.hasOption("id") || cmd.hasOption("taskInstanceId")) {
            taskInstanceId = Long.parseLong(cmd.getOptionValue("taskInstanceId"));
        }
        if (cmd.hasOption("idcol") || cmd.hasOption("idCol")) {
            idCol = cmd.getOptionValue("idCol", UtilTool.DEFAULT_ID_COL);
        } else {
            idCol = UtilTool.DEFAULT_ID_COL;
        }
        if (cmd.hasOption("sample")) {
            String sample = cmd.getOptionValue("sample", "false");
            if (sample.equals("true")) {
                isSample = true;
            }
        }

        algorithmName = args[0];
        return true;
    }

    public boolean beginAlgorithm() {
        return true;
    }

    protected void prepareInputParamsForResult(JSONObject inputParams) {
        for (Option object : cmd.getOptions()) {
            inputParams.put(object.getLongOpt(), object.getValue());
        }
    }

    protected void prepareOutputParamsForResult(JSONArray outputParams, String tableName) {
        Map<String, String> meta = null;
        if (isHive) {
            meta = HiveTool.getTableMetaMap(sparkSession, tableName);
        } else {
            meta = UtilTool.getTableMetaForGp(tableName);
        }
        if (meta == null || meta.isEmpty()) {
            return;
        }
        JSONObject item = new JSONObject();
        item.put("out_table_name", tableName);
        List<String> cols = new ArrayList<>();
        List<String> types = new ArrayList<>();
        for (Map.Entry<String, String> entry : meta.entrySet()) {
            cols.add(entry.getKey());
            types.add(entry.getValue());
        }
        item.put("output_cols", cols);
        item.put("output_types", types);
        outputParams.add(item);
    }

    protected boolean saveHiveTable(Dataset<Row> dataset, String targetTable) {
        String tmpView = String.format("view_%s", uniqueKey);
        dataset.createOrReplaceTempView(tmpView);
        sparkSession.sql("select * from " + tmpView).write().mode("append")
                .saveAsTable(targetTable);
        return true;
    }

    protected OutputResult buildOutputResult(List<String> tableNames, int status, String errorMsg,
                                             Map<String, Object> inputKV) {
        OutputResult outputResult = new OutputResult();
        JSONObject inputParams = new JSONObject();
        JSONArray outputParams = new JSONArray();
        this.prepareInputParamsForResult(inputParams, inputKV);
        for (String tableName : tableNames) {
            this.prepareOutputParamsForResult(outputParams, tableName);
        }
        outputResult.setInputParams(inputParams);
        outputResult.setOutputParams(outputParams);
        outputResult.setStatus(status);
        if (StringUtils.isNotEmpty(errorMsg)) {
            outputResult.setSuccess(errorMsg);
        }
        return outputResult;
    }

    protected void saveOutputResultForHive(SparkSession sparkSession, OutputResult outputResult,
                                           String metaPath) {
        List<String> lists = new ArrayList<>();
        lists.add(JSONObject.toJSONString(outputResult));
        JavaSparkContext javaSparkContext = new JavaSparkContext(sparkSession.sparkContext());
        javaSparkContext.parallelize(lists).coalesce(1).saveAsTextFile(metaPath);
    }

    protected void saveOutputResultForMysql(SparkSession sparkSession, OutputResult outputResult) {
        Connection conn = UtilTool.getConn();
        if (conn != null) {
            String sql = String
                    .format("update aiworks.task_instance set log_info = '%s' where id = %s",
                            JSONObject.toJSONString(outputResult), taskInstanceId);
            logger.debug("when saving result in mysql, sql = {}", sql);
            try {
                Statement statement = conn.createStatement();
                statement.execute(sql);
            } catch (SQLException e) {
                logger.error(e.getMessage(), e);
                e.printStackTrace();
            } finally {
                try {
                    conn.close();
                } catch (SQLException e) {
                    logger.error("close mysql connection failed.");
                    logger.error(e.getMessage(), e);
                }
            }
        } else {
            logger.error("jdbc conn is null");
            System.exit(1);
        }
    }

    protected void prepareInputParamsForResult(JSONObject input, Map<String, Object> inputKV) {
        this.prepareInputParamsForResult(input);
        for (Map.Entry<String, Object> entry : inputKV.entrySet()) {
            input.put(entry.getKey(), entry.getValue());
        }
    }

}
